#include <cmath>
#include <algorithm>
#include <stdio.h>
#include "roi_extractor.h"
#include "amir_cuda_util/cuda_util.h"


namespace amirstan
{
namespace plugin
{
    using namespace amirstan::cuda;
    // Capacity of the fixed-size per-level arrays in FeatData, i.e. the largest
    // number of feature maps one FeatData instance can describe.
    const int kMAX_FEATMAP_SIZE=10;
    // Description of a set of feature maps. The struct is handed to
    // roi_extractor_kernel by value, so everything the kernel needs to locate
    // a level travels in the kernel argument itself.
    // Only the first num_featmap entries of each array are filled in by
    // roi_extractor.
    struct FeatData{
        // Base pointer of each feature map; the kernel casts the entry to
        // scalar_t* and dereferences it on the device.
        const void* data[kMAX_FEATMAP_SIZE];
        // Set from the `n` argument of roi_extractor; not read by the code in
        // this file.
        int batch_size;
        // Channel count, shared by all levels.
        int channels;
        // Height of each feature map.
        int h[kMAX_FEATMAP_SIZE];
        // Width of each feature map.
        int w[kMAX_FEATMAP_SIZE];
        // Factor that maps ROI coordinates onto each feature map; filled in as
        // 1 / stride by roi_extractor.
        float spatial_scale[kMAX_FEATMAP_SIZE];
        // Number of valid entries in the arrays above.
        int num_featmap;
    };

    // Bilinearly samples one feature plane at the fractional position (y, x).
    //
    // bottom_data : start of a plane indexed as bottom_data[row * width + col].
    // height/width: extent of that plane.
    // y, x        : sample position in plane coordinates.
    //
    // Returns 0 when the position lies more than one pixel before the plane
    // or beyond `height` / `width`. Positions in [-1, 0] are snapped to 0, and
    // positions that fall in the last row/column (or up to one past it) are
    // snapped onto that row/column, so all four taps stay inside the plane.
    template <typename scalar_t>
    __device__ scalar_t bilinear_interpolate(const scalar_t *bottom_data,
                                             const int height, const int width,
                                             scalar_t y, scalar_t x) {
      // Reject samples that are clearly outside of the plane.
      const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
      if (outside) {
        return 0;
      }

      // Snap slightly negative coordinates onto the first row / column.
      y = (y <= 0) ? (scalar_t)0 : y;
      x = (x <= 0) ? (scalar_t)0 : x;

      // Rows of the two taps surrounding y.
      int top = (int)y;
      int bottom;
      if (top < height - 1) {
        bottom = top + 1;
      } else {
        // On or past the last row: use it for both taps.
        top = bottom = height - 1;
        y = (scalar_t)top;
      }

      // Columns of the two taps surrounding x.
      int left = (int)x;
      int right;
      if (left < width - 1) {
        right = left + 1;
      } else {
        // On or past the last column: use it for both taps.
        left = right = width - 1;
        x = (scalar_t)left;
      }

      // Fractional distances to the top/left taps and their complements.
      const scalar_t frac_y = y - top;
      const scalar_t frac_x = x - left;
      const scalar_t rest_y = 1. - frac_y;
      const scalar_t rest_x = 1. - frac_x;

      // The four neighbouring values.
      const scalar_t *row_top = bottom_data + top * width;
      const scalar_t *row_bottom = bottom_data + bottom * width;
      const scalar_t v_tl = row_top[left];
      const scalar_t v_tr = row_top[right];
      const scalar_t v_bl = row_bottom[left];
      const scalar_t v_br = row_bottom[right];

      // Each tap is weighted by the area of the opposite sub-rectangle.
      const scalar_t w_tl = rest_y * rest_x;
      const scalar_t w_tr = rest_y * frac_x;
      const scalar_t w_bl = frac_y * rest_x;
      const scalar_t w_br = frac_y * frac_x;

      return (w_tl * v_tl + w_tr * v_tr + w_bl * v_bl + w_br * v_br);
    }

    // Computes one RoIAlign output value: the average of bilinear samples
    // taken inside bin (ph, pw) of a single ROI on channel `c`.
    //
    // bottom_data   : feature tensor indexed as
    //                 ((roi_batch_ind * channels + c) * height + y) * width + x.
    // roi_batch_ind : image index the ROI belongs to.
    // roi_start_* / roi_end_* : ROI corners, already in feature-map coordinates.
    // spatial_scale : accepted for interface symmetry; not read here.
    // pw, ph, c     : output bin column, bin row and channel.
    // sample_num    : samples per bin along each axis; when <= 0 the count is
    //                 ceil(bin size) per axis.
    // channels, height, width      : feature tensor dimensions.
    // pooled_height, pooled_width  : number of bins per ROI.
    // aligned       : when false, the ROI extent is raised to at least 1 x 1;
    //                 when true, it is only clamped to be non-negative.
    //
    // Returns the mean of the samples, or 0 if no sample is taken.
    template <typename scalar_t>
    __device__ scalar_t roi_align_single(const scalar_t *bottom_data,
                                    const int roi_batch_ind,
                                    const scalar_t roi_start_w,
                                    const scalar_t roi_start_h,
                                    const scalar_t roi_end_w,
                                    const scalar_t roi_end_h,
                                    const scalar_t spatial_scale,
                                    const int pw, const int ph, const int c,
                                    const int sample_num, const int channels,
                                    const int height, const int width,
                                    const int pooled_height, const int pooled_width,
                                    const bool aligned) {

        // ROI extent; an inverted box collapses to zero size.
        scalar_t extent_w = fmaxf((scalar_t)roi_end_w - (scalar_t)roi_start_w, 0.);
        scalar_t extent_h = fmaxf((scalar_t)roi_end_h - (scalar_t)roi_start_h, 0.);
        if (!aligned) {
          // Non-aligned mode never lets the ROI shrink below one pixel.
          extent_w = max(extent_w, (scalar_t)1.);
          extent_h = max(extent_h, (scalar_t)1.);
        }

        // Size of one output bin in feature-map units.
        const scalar_t bin_size_h = extent_h / pooled_height;
        const scalar_t bin_size_w = extent_w / pooled_width;

        // Sampling grid inside each bin: fixed, or adaptive to the bin size.
        const int grid_h = (sample_num > 0) ? sample_num : ceil(bin_size_h);
        const int grid_w = (sample_num > 0) ? sample_num : ceil(bin_size_w);

        // Plane of the requested image and channel.
        const scalar_t *plane =
            bottom_data + (roi_batch_ind * channels + c) * height * width;

        // Accumulate samples placed at the centres of the grid cells.
        scalar_t acc = 0;
        #pragma unroll
        for (int gy = 0; gy < grid_h; gy++) {
          const scalar_t y = roi_start_h + ph * bin_size_h +
                             (scalar_t)(gy + scalar_t(.5f)) * bin_size_h /
                                 (scalar_t)(grid_h);
          #pragma unroll
          for (int gx = 0; gx < grid_w; gx++) {
            const scalar_t x = roi_start_w + pw * bin_size_w +
                               (scalar_t)(gx + scalar_t(.5f)) * bin_size_w /
                                   (scalar_t)(grid_w);
            acc += bilinear_interpolate<scalar_t>(plane, height, width, y, x);
          }
        }

        // Average; the divisor is kept >= 1 so an empty grid yields 0.
        const int sample_count = max(grid_h * grid_w, 1);
        acc /= sample_count;

        return acc;
    }

    // Multi-level RoIAlign kernel: each loop iteration produces one output
    // element.
    //
    // The flat output index decomposes as
    //   index = ((n * channels + c) * pooled_height + ph) * pooled_width + pw
    // where n is the ROI index. Each ROI occupies five consecutive values in
    // bottom_rois: [batch_index, x0, y0, x1, y1].
    //
    // For every element the kernel
    //   1. picks a feature level from the size of the original box:
    //      floor(log2(sqrt(w * h) / finest_scale + 1e-6)), clamped to
    //      [0, num_featmap - 1];
    //   2. if roi_scale_factor > 0, rescales the box around its centre;
    //   3. maps the box onto the chosen level (subtracting 0.5 when `aligned`)
    //      and evaluates roi_align_single for bin (ph, pw) of channel c.
    //
    // CUDA_KERNEL_LOOP comes from outside this file; presumably it iterates
    // `index` over [0, nThreads) -- confirm in cuda_util.h.
    template<typename scalar_t>
    __global__ void roi_extractor_kernel(
        scalar_t* output, 
        const scalar_t *bottom_rois,
        FeatData feat_data,
        const int sample_num, const float roi_scale_factor, const int finest_scale,
        const int pooled_height, const int pooled_width,
        const bool aligned, 
        int nThreads){
        CUDA_KERNEL_LOOP(index, nThreads){
            // Peel the coordinates off the flat index, fastest axis first.
            const int channels = feat_data.channels;
            int remaining = index;
            const int pw = remaining % pooled_width;
            remaining /= pooled_width;
            const int ph = remaining % pooled_height;
            remaining /= pooled_height;
            const int c = remaining % channels;
            const int n = remaining / channels;

            // Load the ROI record.
            const scalar_t *roi = bottom_rois + n * 5;
            const int roi_batch_ind = roi[0];
            scalar_t box_x0 = roi[1];
            scalar_t box_y0 = roi[2];
            scalar_t box_x1 = roi[3];
            scalar_t box_y1 = roi[4];

            // Level selection is based on the box before any rescaling.
            const scalar_t scale = sqrtf((box_y1 - box_y0 + 1.)*(box_x1 - box_x0 + 1.));
            const float raw_lvl = floorf(log2f(scale/(scalar_t)(finest_scale)+1e-6));
            const int target_lvls = fminf(feat_data.num_featmap-1, fmaxf(0, raw_lvl));

            // Optionally grow or shrink the box around its centre.
            if(roi_scale_factor>0.){
              const scalar_t center_x = (box_x0+box_x1)*0.5;
              const scalar_t center_y = (box_y0+box_y1)*0.5;
              const scalar_t scaled_w = (box_x1-box_x0 +1)*roi_scale_factor;
              const scalar_t scaled_h = (box_y1-box_y0 +1)*roi_scale_factor;

              box_x0 = center_x - scaled_w*0.5 + 0.5;
              box_x1 = center_x + scaled_w*0.5 - 0.5;
              box_y0 = center_y - scaled_h*0.5 + 0.5;
              box_y1 = center_y + scaled_h*0.5 - 0.5;
            }

            // Geometry of the selected level.
            const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls];
            const scalar_t shift = aligned ? (scalar_t)0.5 : (scalar_t)0.0;

            // Project the box onto the level and pool the requested bin.
            output[index] = roi_align_single<scalar_t>(
                                        (scalar_t*)feat_data.data[target_lvls],
                                        roi_batch_ind,
                                        box_x0 * spatial_scale - shift,
                                        box_y0 * spatial_scale - shift,
                                        (box_x1) * spatial_scale - shift,
                                        (box_y1) * spatial_scale - shift,
                                        spatial_scale,
                                        pw, ph, c,
                                        sample_num, channels,
                                        feat_data.h[target_lvls],
                                        feat_data.w[target_lvls],
                                        pooled_height, pooled_width,
                                        aligned);
        }
    }
    
    // Host-side launcher for roi_extractor_kernel.
    //
    // output           : destination buffer; receives
    //                    num_rois * c * out_size * out_size values laid out as
    //                    [roi][channel][bin row][bin col]. Written by the kernel.
    // rois             : ROI records, five values each
    //                    [batch_index, x0, y0, x1, y1]. Read by the kernel.
    // num_rois         : number of ROI records.
    // feats            : host-readable array of num_feats feature-map pointers;
    //                    the pointers themselves are dereferenced by the kernel.
    // num_feats        : number of feature levels; must be in
    //                    [1, kMAX_FEATMAP_SIZE].
    // n                : stored as FeatData::batch_size.
    // c                : channel count shared by all levels.
    // h, w, strides    : host arrays with num_feats entries: per-level height,
    //                    width and stride (spatial scale is 1 / stride).
    // out_size         : pooled output is out_size x out_size bins.
    // sample_num       : samples per bin and axis (<= 0 selects adaptive count).
    // roi_scale_factor : if > 0, ROIs are rescaled around their centre.
    // finest_scale     : box-size threshold used for level assignment.
    // stream           : stream the kernel is enqueued on.
    //
    // The launch is asynchronous and always uses aligned = true. Nothing is
    // launched (and output is left untouched) when there is no work to do or
    // when num_feats is outside the supported range; the latter is reported on
    // stderr.
    template<typename T>
    void roi_extractor(T* output, 
                        const T* rois, 
                        int num_rois,
                        const void *const *feats, 
                        int num_feats,
                        int n,
                        int c,
                        int *h,
                        int *w,
                        int *strides,
                        int out_size,
                        int sample_num,
                        float roi_scale_factor,
                        int finest_scale,
                        cudaStream_t stream){
        // FeatData holds fixed-capacity arrays: more levels than that would
        // write past their end, and zero levels would make the kernel index
        // level -1.
        if(num_feats <= 0 || num_feats > kMAX_FEATMAP_SIZE){
            fprintf(stderr,
                    "roi_extractor: num_feats=%d is outside the supported range [1, %d]\n",
                    num_feats, kMAX_FEATMAP_SIZE);
            return;
        }

        const int pooled_height = out_size;
        const int pooled_width = out_size;
        const int nThreads = num_rois * c * pooled_height * pooled_width;
        // An empty problem must not reach the launch: a zero-sized grid is an
        // invalid launch configuration.
        if(nThreads <= 0){
            return;
        }

        // Value-initialise so unused slots are zero rather than stack garbage
        // when the struct is copied to the device as a kernel argument.
        FeatData feat_data = FeatData();
        feat_data.batch_size = n;
        feat_data.channels = c;
        feat_data.num_featmap = num_feats;
        for(int i=0;i< num_feats;++i){
            feat_data.data[i] = feats[i];
            feat_data.h[i] = h[i];
            feat_data.w[i] = w[i];
            feat_data.spatial_scale[i] = 1./float(strides[i]);
        }

        const bool aligned = true;
        roi_extractor_kernel<T><<<GET_BLOCKS(nThreads), CUDA_NUM_THREADS,0,stream>>>(
            output, rois,
            feat_data,
            sample_num, roi_scale_factor, finest_scale, 
            pooled_height, pooled_width,
            aligned,
            nThreads);
    }

    // Explicit instantiation: emits the float version of the roi_extractor
    // launcher (and the kernel it launches) in this translation unit.
    template void roi_extractor<float>(
        float* output,
        const float* rois,
        int num_rois,
        const void *const *feats,
        int num_feats,
        int n, int c,
        int *h, int *w, int *strides,
        int out_size, int sample_num,
        float roi_scale_factor, int finest_scale,
        cudaStream_t stream);


}
}